"""
@Author: Bright
@File: custom_chat_session.py
@Time: 2025/10/7
@Desc:
"""
import bs4
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.globals import set_debug
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_ollama import OllamaLLM
from langchain_text_splitters import RecursiveCharacterTextSplitter

from model.MyOllamaEmbeddings import MyOllamaEmbeddings

# set_debug(True)  ## 打印调试日志
# 创建一个webBaseLoader对象，用于从指定网址加载文档
strainer = bs4.SoupStrainer(["p", "h1", "h2", "h3", "div.content", "article"])
loader = WebBaseLoader(
    # web_path="https://lilianweng.github.io/posts/2023-06-23-agent",
    web_path="https://baike.baidu.com/item/%E4%BB%BB%E5%8A%A1%E5%88%86%E8%A7%A3%E6%B3%95",
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            # class_=["post-content", "post-title", "post-header"]
        )
    )
)
docs = loader.load()  # 加载文档
# print(docs)

# 创建一个RecursiveCharacterTextSplitter对象，用于将文档拆分成较小的文本块
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
splits = text_splitter.split_documents(docs)
# 创建一个Chroma 对象，用户存储文本块的向量表示
embeddings = MyOllamaEmbeddings(model="nomic-embed-text")
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
# 将向量存储转换为检索器
retriever = vectorstore.as_retriever()

# 定义系统提示词模板
system_prompt = (
    "您是一个用于问答任务的助手。"
    "使用如下检索的上下文片段来回答问题。"
    "如果您不知道答案，请说您不知道。"
    "最多用三句话，保持回答简洁。"
    "\n\n"
    "{context}"
)
# 创建一个ChatPromptTemplate对象，用户生成提示词
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
# 创建一个带有聊天历史文化记录的提示词模板
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# 创建一个ChatOpen AI 对象，表示聊天模型
llm = OllamaLLM(model="llama3.2:3b")
# 【AI】创建一个AI问答链
question_answer_chain = create_stuff_documents_chain(llm, prompt)
# 【RAG】创建一个检索链，将检索器和问答链结合[向量数据库]
rag_chain = create_retrieval_chain(retriever, question_answer_chain)

# 定义上下文化问题的系统提示词
contextualize_q_system_prompt = (
    "给定聊天历史和最新的用户问题。"
    "该问题可能引用聊天历史中的上下文。"
    "重新构造一个可以在没有聊天历史的情况下理解的独立问题。"
    "如果需要，不要回答问题，只需要重新构造问题并返回。"
)
# 定义上下文化问题的提示词模版
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
"""" [新组件 ]创建一个带有历史记录感知的检索器 """
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
# 创建一个带有聊天历史记录的问答链
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
# 创建一个带有历史记录感知的检索链
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# 创建一个字典，用户存储聊天历史记录
store = {}


# 定义一个函数，用于获取指定会话的聊天历史记录
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]


# 创建一个RunnableWithMessageHistory对象，用于管理有状态的聊天历史记录
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# 调用有状态的检索链，获取问答
response = conversational_rag_chain.invoke(
    input={"input": "什么是任务分解？"},
    config={"configurable": {"session_id": "abc123"}},  # 在store中构建一个键为“abc123”的键
)["answer"]
print(response)
## 再次调用有状态的检索链，获取另一个回答
response2 = conversational_rag_chain.invoke(
    input={"input": "我刚刚问了什么？"},
    config={"configurable": {"session_id": "abc123"}},
)["answer"]
print(response2)

## 再次调用有状态的检索链，换个session_id
response3 = conversational_rag_chain.invoke(
    input={"input": "我刚刚问了什么？"},
    config={"configurable": {"session_id": "abc456"}},
)["answer"]
print(response3)

# 打印存储在会话“abc123”中的所有消息
for message in store["abc123"].messages:
    if isinstance(message, AIMessage):
        prefix = "AI"
    else:
        prefix = "User"
    print(f"\n{prefix}: {message.text}")
